""" Furthest point sampling
Original author: Haoqiang Fan
Modified by Charles R. Qi
All Rights Reserved. 2017.
"""
import tensorflow as tf
from tensorflow.python.framework import ops
import sys
import os

# Directory that contains this file; the compiled op library is located
# relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make this directory importable.
sys.path.append(BASE_DIR)
# Load the compiled custom-op library from <BASE_DIR>/build/libtf_sampling.so.
# This runs at import time, so the library must already be built. The wrappers
# in this file call the ops exposed on the returned module object.
sampling_module = tf.load_op_library(
    os.path.join(BASE_DIR, "build", "libtf_sampling.so")
)


def prob_sample(inp, inpr):
    """Call the custom ``prob_sample`` op from the compiled sampling library.

    NOTE(review): the sampling semantics live in the compiled kernel; judging
    by the name and shapes, ``inp`` presumably holds per-category weights and
    ``inpr`` one random number per requested sample -- confirm against the
    C++ source.

    Args:
        inp: float32 tensor of shape (batch_size, ncategory).
        inpr: float32 tensor of shape (batch_size, npoints).

    Returns:
        int32 tensor of shape (batch_size, npoints).
    """
    sampled = sampling_module.prob_sample(inp, inpr)
    return sampled


# Register the "ProbSample" op type as non-differentiable so that gradient
# construction does not look for a gradient function for it.
ops.NoGradient("ProbSample")


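# The TF 1.0 API requires shape functions to be set in C++, so the Python
# shape registration below is kept only as commented-out reference.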
# @tf.RegisterShape('ProbSample')
# def _prob_sample_shape(op):
#    shape1=op.inputs[0].get_shape().with_rank(2)
#    shape2=op.inputs[1].get_shape().with_rank(2)
#    return [tf.TensorShape([shape2.dims[0],shape2.dims[1]])]
def gather_point(inp, idx):
    """Call the custom ``gather_point`` op from the compiled sampling library.

    NOTE(review): judging by the name and shapes, this presumably picks, for
    each batch element, the points of ``inp`` at the positions listed in
    ``idx`` -- confirm against the C++ kernel.

    Args:
        inp: float32 tensor of shape (batch_size, ndataset, 3).
        idx: int32 tensor of shape (batch_size, npoints).

    Returns:
        float32 tensor of shape (batch_size, npoints, 3).
    """
    gathered = sampling_module.gather_point(inp, idx)
    return gathered


# @tf.RegisterShape('GatherPoint')
# def _gather_point_shape(op):
#    shape1=op.inputs[0].get_shape().with_rank(3)
#    shape2=op.inputs[1].get_shape().with_rank(2)
#    return [tf.TensorShape([shape1.dims[0],shape2.dims[1],shape1.dims[2]])]
@tf.RegisterGradient("GatherPoint")
def _gather_point_grad(op, out_g):
    """Gradient function registered for the ``GatherPoint`` op.

    Args:
        op: the forward ``GatherPoint`` operation; its first input is the
            point tensor and its second input is the index tensor.
        out_g: gradient with respect to the op's output.

    Returns:
        A two-element list with one entry per op input: the gradient for the
        point tensor, computed by the custom ``gather_point_grad`` op, and
        ``None`` for the index tensor.
    """
    points, indices = op.inputs[0], op.inputs[1]
    points_grad = sampling_module.gather_point_grad(points, indices, out_g)
    # No gradient is propagated to the index input.
    return [points_grad, None]


def farthest_point_sample(npoint, inp):
    """Call the custom ``farthest_point_sample`` op from the sampling library.

    Note the argument order: this wrapper takes ``(npoint, inp)`` but forwards
    them to the underlying op as ``(inp, npoint)``.

    Args:
        npoint: int32, the number of points to sample per batch element.
        inp: float32 tensor of shape (batch_size, ndataset, 3).

    Returns:
        int32 tensor of shape (batch_size, npoint).
    """
    sampled_idx = sampling_module.farthest_point_sample(inp, npoint)
    return sampled_idx


# Register the "FarthestPointSample" op type as non-differentiable so that
# gradient construction does not look for a gradient function for it.
ops.NoGradient("FarthestPointSample")
